{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## CS 224N Lecture 3: Word Window Classification\n",
    "\n",
    "### Pytorch Exploration\n",
    "\n",
    "### Author: Matthew Lamm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pprint\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "pp = pprint.PrettyPrinter()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Our Data\n",
    "\n",
    "The task at hand is to assign a label of 1 to words in a sentence that correspond with a LOCATION, and a label of 0 to everything else. \n",
    "\n",
    "In this simplified example, we only ever see spans of length 1."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_sents = [s.lower().split() for s in [\"we 'll always have Paris\",\n",
    "                                           \"I live in Germany\",\n",
    "                                           \"He comes from Denmark\",\n",
    "                                           \"The capital of Denmark is Copenhagen\"]]\n",
    "train_labels = [[0, 0, 0, 0, 1],\n",
    "                [0, 0, 0, 1],\n",
    "                [0, 0, 0, 1],\n",
    "                [0, 0, 0, 1, 0, 1]]\n",
    "\n",
    "assert all([len(train_sents[i]) == len(train_labels[i]) for i in range(len(train_sents))])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_sents = [s.lower().split() for s in [\"She comes from Paris\"]]\n",
    "test_labels = [[0, 0, 0, 1]]\n",
    "\n",
    "assert all([len(test_sents[i]) == len(test_labels[i]) for i in range(len(test_sents))])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Creating a dataset of batched tensors."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "PyTorch (like other deep learning frameworks) is optimized to work on __tensors__, which can be thought of as a generalization of vectors and matrices with arbitrarily large rank.\n",
    "\n",
    "Here well go over how to translate data to a list of vocabulary indices, and how to construct *batch tensors* out of the data for easy input to our model. \n",
    "\n",
    "We'll use the *torch.utils.data.DataLoader* object handle ease of batching and iteration."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Converting tokenized sentence lists to vocabulary indices.\n",
    "\n",
    "Let's assume we have the following vocabulary:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "id_2_word = [\"<pad>\", \"<unk>\", \"we\", \"always\", \"have\", \"paris\",\n",
    "              \"i\", \"live\", \"in\", \"germany\",\n",
    "              \"he\", \"comes\", \"from\", \"denmark\",\n",
    "              \"the\", \"of\", \"is\", \"copenhagen\"]\n",
    "word_2_id = {w:i for i,w in enumerate(id_2_word)}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['we', \"'ll\", 'always', 'have', 'paris']\n"
     ]
    }
   ],
   "source": [
    "instance = train_sents[0]\n",
    "print(instance)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "def convert_tokens_to_inds(sentence, word_2_id):\n",
    "    return [word_2_id.get(t, word_2_id[\"<unk>\"]) for t in sentence]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2, 1, 3, 4, 5]\n"
     ]
    }
   ],
   "source": [
    "token_inds = convert_tokens_to_inds(instance, word_2_id)\n",
    "pp.pprint(token_inds)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's convince ourselves that worked:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['we', '<unk>', 'always', 'have', 'paris']\n"
     ]
    }
   ],
   "source": [
    "print([id_2_word[tok_idx] for tok_idx in token_inds])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Padding for windows."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In the word window classifier, for each word in the sentence we want to get the +/- n window around the word, where 0 <= n < len(sentence).\n",
    "\n",
    "In order for such windows to be defined for words at the beginning and ends of the sentence, we actually want to insert padding around the sentence before converting to indices:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "def pad_sentence_for_window(sentence, window_size, pad_token=\"<pad>\"):\n",
    "    return [pad_token]*window_size + sentence + [pad_token]*window_size "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['<pad>', '<pad>', 'we', \"'ll\", 'always', 'have', 'paris', '<pad>', '<pad>']\n"
     ]
    }
   ],
   "source": [
    "window_size = 2\n",
    "instance = pad_sentence_for_window(train_sents[0], window_size)\n",
    "print(instance)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's make sure this works with our vocabulary:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['<pad>', '<pad>', 'we', '<unk>', 'always', 'have', 'paris', '<pad>', '<pad>']\n",
      "['<pad>', '<pad>', 'i', 'live', 'in', 'germany', '<pad>', '<pad>']\n",
      "['<pad>', '<pad>', 'he', 'comes', 'from', 'denmark', '<pad>', '<pad>']\n",
      "['<pad>', '<pad>', 'the', '<unk>', 'of', 'denmark', 'is', 'copenhagen', '<pad>', '<pad>']\n"
     ]
    }
   ],
   "source": [
    "for sent in train_sents:\n",
    "    tok_idxs = convert_tokens_to_inds(pad_sentence_for_window(sent, window_size), word_2_id)\n",
    "    print([id_2_word[idx] for idx in tok_idxs])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Batching sentences together with a DataLoader"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "When we train our model, we rarely update with respect to a single training instance at a time, because a single instance provides a very noisy estimate of the global loss's gradient. We instead construct small *batches* of data, and update parameters for each batch. "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Given some batch size, we want to construct batch tensors out of the word index lists we've just created with our vocab.\n",
    "\n",
    "For each length B list of inputs, we'll have to:\n",
    "\n",
    "    (1) Add window padding to sentences in the batch like we just saw.\n",
    "    (2) Add additional padding so that each sentence in the batch is the same length.\n",
    "    (3) Make sure our labels are in the desired format.\n",
    "\n",
    "At the level of the dataest we want:\n",
    "\n",
    "    (4) Easy shuffling, because shuffling from one training epoch to the next gets rid of \n",
    "        pathological batches that are tough to learn from.\n",
    "    (5) Making sure we shuffle inputs and their labels together!\n",
    "    \n",
    "PyTorch provides us with an object *torch.utils.data.DataLoader* that gets us (4) and (5). All that's required of us is to specify a *collate_fn* that tells it how to do (1), (2), and (3). "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "('raw train label instance', tensor([0, 0, 0, 0, 1]))\n",
      "torch.Size([5])\n"
     ]
    }
   ],
   "source": [
    "l = torch.LongTensor(train_labels[0])\n",
    "pp.pprint((\"raw train label instance\", l))\n",
    "print(l.size())\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "('unfilled label instance',\n",
      " tensor([[0., 0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0., 0.]]))\n",
      "torch.Size([2, 5])\n"
     ]
    }
   ],
   "source": [
    "one_hots = torch.zeros((2, len(l)))\n",
    "pp.pprint((\"unfilled label instance\", one_hots))\n",
    "print(one_hots.size())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "('one-hot labels', tensor([[0., 0., 0., 0., 0.],\n",
      "        [0., 0., 0., 0., 1.]]))\n"
     ]
    }
   ],
   "source": [
    "one_hots[1] = l\n",
    "pp.pprint((\"one-hot labels\", one_hots))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "('one-hot labels', tensor([[1., 1., 1., 1., 0.],\n",
      "        [0., 0., 0., 0., 1.]]))\n"
     ]
    }
   ],
   "source": [
    "l_not = ~l.byte()\n",
    "one_hots[0] = l_not\n",
    "pp.pprint((\"one-hot labels\", one_hots))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data import DataLoader\n",
    "from functools import partial"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [],
   "source": [
    "def my_collate(data, window_size, word_2_id):\n",
    "    \"\"\"\n",
    "    For some chunk of sentences and labels\n",
    "        -add winow padding\n",
    "        -pad for lengths using pad_sequence\n",
    "        -convert our labels to one-hots\n",
    "        -return padded inputs, one-hot labels, and lengths\n",
    "    \"\"\"\n",
    "    \n",
    "    x_s, y_s = zip(*data)\n",
    "\n",
    "    # deal with input sentences as we've seen\n",
    "    window_padded = [convert_tokens_to_inds(pad_sentence_for_window(sentence, window_size), word_2_id)\n",
    "                                                                                  for sentence in x_s]\n",
    "    # append zeros to each list of token ids in batch so that they are all the same length\n",
    "    padded = nn.utils.rnn.pad_sequence([torch.LongTensor(t) for t in window_padded], batch_first=True)\n",
    "    \n",
    "    # convert labels to one-hots\n",
    "    labels = []\n",
    "    lengths = []\n",
    "    for y in y_s:\n",
    "        lengths.append(len(y))\n",
    "        label = torch.zeros((len(y),2 ))\n",
    "        true = torch.LongTensor(y) \n",
    "        false = ~true.byte()\n",
    "        label[:, 0] = false\n",
    "        label[:, 1] = true\n",
    "        labels.append(label)\n",
    "    padded_labels = nn.utils.rnn.pad_sequence(labels, batch_first=True)\n",
    "    \n",
    "    return padded.long(), padded_labels, torch.LongTensor(lengths)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shuffle True is good practice for train loaders.\n",
    "# Use functools.partial to construct a partially populated collate function\n",
    "example_loader = DataLoader(list(zip(train_sents, \n",
    "                                                      train_labels)), \n",
    "                                             batch_size=2, \n",
    "                                             shuffle=True, \n",
    "                                             collate_fn=partial(my_collate, window_size=2, word_2_id=word_2_id))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "('inputs',\n",
      " tensor([[ 0,  0,  2,  1,  3,  4,  5,  0,  0],\n",
      "        [ 0,  0, 10, 11, 12, 13,  0,  0,  0]]),\n",
      " torch.Size([2, 9]))\n",
      "('labels',\n",
      " tensor([[[1., 0.],\n",
      "         [1., 0.],\n",
      "         [1., 0.],\n",
      "         [1., 0.],\n",
      "         [0., 1.]],\n",
      "\n",
      "        [[1., 0.],\n",
      "         [1., 0.],\n",
      "         [1., 0.],\n",
      "         [0., 1.],\n",
      "         [0., 0.]]]),\n",
      " torch.Size([2, 5, 2]))\n",
      "tensor([5, 4])\n"
     ]
    }
   ],
   "source": [
    "for batched_input, batched_labels, batch_lengths in example_loader:\n",
    "    pp.pprint((\"inputs\", batched_input, batched_input.size()))\n",
    "    pp.pprint((\"labels\", batched_labels, batched_labels.size()))\n",
    "    pp.pprint(batch_lengths)\n",
    "    break"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Modeling\n",
    "\n",
    "### Thinking through vectorization of word windows.\n",
    "Before we go ahead and build our model, let's think about the first thing it needs to do to its inputs.\n",
    "\n",
    "We're passed batches of sentences. For each sentence i in the batch, for each word j in the sentence, we want to construct a single tensor out of the embeddings surrounding word j in the +/- n window.\n",
    "\n",
    "Thus, the first thing we're going to need a (B, L, 2N+1) tensor of token indices."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A *terrible* but nevertheless informative *iterative* solution looks something like the following, where we iterate through batch elements in our (dummy), iterating non-padded word positions in those, and for each non-padded word position, construct a window:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[0, 0, 1, 2, 3, 4, 0, 0],\n",
      "        [0, 0, 5, 6, 7, 8, 0, 0]])\n"
     ]
    }
   ],
   "source": [
    "dummy_input = torch.zeros(2, 8).long()\n",
    "dummy_input[:,2:-2] = torch.arange(1,9).view(2,4)\n",
    "pp.pprint(dummy_input)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([2, 4, 5])\n",
      "tensor([[[0, 0, 1, 2, 3],\n",
      "         [0, 1, 2, 3, 4],\n",
      "         [1, 2, 3, 4, 0],\n",
      "         [2, 3, 4, 0, 0]],\n",
      "\n",
      "        [[0, 0, 5, 6, 7],\n",
      "         [0, 5, 6, 7, 8],\n",
      "         [5, 6, 7, 8, 0],\n",
      "         [6, 7, 8, 0, 0]]])\n"
     ]
    }
   ],
   "source": [
    "dummy_output = [[[dummy_input[i, j-2+k].item() for k in range(2*2+1)] \n",
    "                                                     for j in range(2, 6)] \n",
    "                                                            for i in range(2)]\n",
    "dummy_output = torch.LongTensor(dummy_output)\n",
    "print(dummy_output.size())\n",
    "pp.pprint(dummy_output)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "*Technically* it works: For each element in the batch, for each word in the original sentence and ignoring window padding, we've got the 5 token indices centered at that word. But in practice will be crazy slow."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Instead, we ideally want to find the right tensor operation in the PyTorch arsenal. Here, that happens to be __Tensor.unfold__."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[0, 0, 1, 2, 3],\n",
       "         [0, 1, 2, 3, 4],\n",
       "         [1, 2, 3, 4, 0],\n",
       "         [2, 3, 4, 0, 0]],\n",
       "\n",
       "        [[0, 0, 5, 6, 7],\n",
       "         [0, 5, 6, 7, 8],\n",
       "         [5, 6, 7, 8, 0],\n",
       "         [6, 7, 8, 0, 0]]])"
      ]
     },
     "execution_count": 51,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dummy_input.unfold(1, 2*2+1, 1)"
   ]
  },
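  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see why this call yields exactly one window per real word, it helps to count the windows. `Tensor.unfold(dimension, size, step)` returns every slice of length `size` along `dimension`, taken `step` apart. As a worked check (the symbols $T$, $n$, and $\\tilde{L}$ are our own notation, not the lecture's): for a sentence of $T$ words padded with $n$ tokens on each side, the padded length is $L = T + 2n$, and with `size` $= 2n + 1$ and `step` $= 1$ the number of windows is\n",
    "\n",
    "$$\\tilde{L} = \\left\\lfloor \\frac{L - (2n + 1)}{1} \\right\\rfloor + 1 = L - 2n = T.$$\n",
    "\n",
    "For `dummy_input` above, $L = 8$ and $n = 2$, so $\\tilde{L} = 8 - 5 + 1 = 4$, which matches the four windows per batch element in the output."
   ]
  },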
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### A model in full."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In PyTorch, we implement models by extending the nn.Module class. Minimally, this requires implementing an *\\_\\_init\\_\\_* function and a *forward* function.\n",
    "\n",
    "In *\\_\\_init\\_\\_* we want to store model parameters (weights) and hyperparameters (dimensions).\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SoftmaxWordWindowClassifier(nn.Module):\n",
    "    \"\"\"\n",
    "    A one-layer, binary word-window classifier.\n",
    "    \"\"\"\n",
    "    def __init__(self, config, vocab_size, pad_idx=0):\n",
    "        super(SoftmaxWordWindowClassifier, self).__init__()\n",
    "        \"\"\"\n",
    "        Instance variables.\n",
    "        \"\"\"\n",
    "        self.window_size = 2*config[\"half_window\"]+1\n",
    "        self.embed_dim = config[\"embed_dim\"]\n",
    "        self.hidden_dim = config[\"hidden_dim\"]\n",
    "        self.num_classes = config[\"num_classes\"]\n",
    "        self.freeze_embeddings = config[\"freeze_embeddings\"]\n",
    "        \n",
    "        \"\"\"\n",
    "        Embedding layer\n",
    "        -model holds an embedding for each layer in our vocab\n",
    "        -sets aside a special index in the embedding matrix for padding vector (of zeros)\n",
    "        -by default, embeddings are parameters (so gradients pass through them)\n",
    "        \"\"\"\n",
    "        self.embed_layer = nn.Embedding(vocab_size, self.embed_dim, padding_idx=pad_idx)\n",
    "        if self.freeze_embeddings:\n",
    "            self.embed_layer.weight.requires_grad = False\n",
    "        \n",
    "        \"\"\"\n",
    "        Hidden layer\n",
    "        -we want to map embedded word windows of dim (window_size+1)*self.embed_dim to a hidden layer.\n",
    "        -nn.Sequential allows you to efficiently specify sequentially structured models\n",
    "            -first the linear transformation is evoked on the embedded word windows\n",
    "            -next the nonlinear transformation tanh is evoked.\n",
    "        \"\"\"\n",
    "        self.hidden_layer = nn.Sequential(nn.Linear(self.window_size*self.embed_dim, \n",
    "                                                    self.hidden_dim), \n",
    "                                          nn.Tanh())\n",
    "        \n",
    "        \"\"\"\n",
    "        Output layer\n",
    "        -we want to map elements of the output layer (of size self.hidden dim) to a number of classes.\n",
    "        \"\"\"\n",
    "        self.output_layer = nn.Linear(self.hidden_dim, self.num_classes)\n",
    "        \n",
    "        \"\"\"\n",
    "        Softmax\n",
    "        -The final step of the softmax classifier: mapping final hidden layer to class scores.\n",
    "        -pytorch has both logsoftmax and softmax functions (and many others)\n",
    "        -since our loss is the negative LOG likelihood, we use logsoftmax\n",
    "        -technically you can take the softmax, and take the log but PyTorch's implementation\n",
    "         is optimized to avoid numerical underflow issues.\n",
    "        \"\"\"\n",
    "        self.log_softmax = nn.LogSoftmax(dim=2)\n",
    "        \n",
    "    def forward(self, inputs):\n",
    "        \"\"\"\n",
    "        Let B:= batch_size\n",
    "            L:= window-padded sentence length\n",
    "            D:= self.embed_dim\n",
    "            S:= self.window_size\n",
    "            H:= self.hidden_dim\n",
    "            \n",
    "        inputs: a (B, L) tensor of token indices\n",
    "        \"\"\"\n",
    "        B, L = inputs.size()\n",
    "        \n",
    "        \"\"\"\n",
    "        Reshaping.\n",
    "        Takes in a (B, L) LongTensor\n",
    "        Outputs a (B, L~, S) LongTensor\n",
    "        \"\"\"\n",
    "        # Fist, get our word windows for each word in our input.\n",
    "        token_windows = inputs.unfold(1, self.window_size, 1)\n",
    "        _, adjusted_length, _ = token_windows.size()\n",
    "        \n",
    "        # Good idea to do internal tensor-size sanity checks, at the least in comments!\n",
    "        assert token_windows.size() == (B, adjusted_length, self.window_size)\n",
    "        \n",
    "        \"\"\"\n",
    "        Embedding.\n",
    "        Takes in a torch.LongTensor of size (B, L~, S) \n",
    "        Outputs a (B, L~, S, D) FloatTensor.\n",
    "        \"\"\"\n",
    "        embedded_windows = self.embed_layer(token_windows)\n",
    "        \n",
    "        \"\"\"\n",
    "        Reshaping.\n",
    "        Takes in a (B, L~, S, D) FloatTensor.\n",
    "        Resizes it into a (B, L~, S*D) FloatTensor.\n",
    "        -1 argument \"infers\" what the last dimension should be based on leftover axes.\n",
    "        \"\"\"\n",
    "        embedded_windows = embedded_windows.view(B, adjusted_length, -1)\n",
    "        \n",
    "        \"\"\"\n",
    "        Layer 1.\n",
    "        Takes in a (B, L~, S*D) FloatTensor.\n",
    "        Resizes it into a (B, L~, H) FloatTensor\n",
    "        \"\"\"\n",
    "        layer_1 = self.hidden_layer(embedded_windows)\n",
    "        \n",
    "        \"\"\"\n",
    "        Layer 2\n",
    "        Takes in a (B, L~, H) FloatTensor.\n",
    "        Resizes it into a (B, L~, 2) FloatTensor.\n",
    "        \"\"\"\n",
    "        output = self.output_layer(layer_1)\n",
    "        \n",
    "        \"\"\"\n",
    "        Softmax.\n",
    "        Takes in a (B, L~, 2) FloatTensor of unnormalized class scores.\n",
    "        Outputs a (B, L~, 2) FloatTensor of (log-)normalized class scores.\n",
    "        \"\"\"\n",
    "        output = self.log_softmax(output)\n",
    "        \n",
    "        return output"
   ]
  },
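  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `Softmax` comment in `__init__` can be made concrete with the standard identity below. This is a sketch of the usual textbook stabilization, assuming class scores $z_1, \\dots, z_C$ (our notation); it is not a description of PyTorch's internal code:\n",
    "\n",
    "$$\\log \\mathrm{softmax}(z)_c = z_c - \\log \\sum_{k=1}^{C} e^{z_k} = (z_c - m) - \\log \\sum_{k=1}^{C} e^{z_k - m}, \\qquad m = \\max_k z_k.$$\n",
    "\n",
    "Subtracting $m$ keeps every exponent at or below zero, so the sum cannot overflow, and the log is never applied to a probability that has already underflowed to zero."
   ]
  },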
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Training.\n",
    "\n",
    "Now that we've got a model, we have to train it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {},
   "outputs": [],
   "source": [
    "def loss_function(outputs, labels, lengths):\n",
    "    \"\"\"Computes negative LL loss on a batch of model predictions.\"\"\"\n",
    "    B, L, num_classes = outputs.size()\n",
    "    num_elems = lengths.sum().float()\n",
    "        \n",
    "    # get only the values with non-zero labels\n",
    "    loss = outputs*labels\n",
    "    \n",
    "    # rescale average\n",
    "    return -loss.sum() / num_elems"
   ]
  },
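  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Written out, `loss_function` computes a length-normalized negative log-likelihood. In the sketch below the symbols are our own: $p_{b,t,c}$ is the model's probability of class $c$ for word $t$ of sentence $b$ (so `outputs` holds $\\log p_{b,t,c}$), $y_{b,t,c}$ is the one-hot label, and $\\ell_b$ is the unpadded length of sentence $b$ (`lengths`):\n",
    "\n",
    "$$\\mathrm{loss} = -\\frac{1}{\\sum_{b=1}^{B} \\ell_b} \\sum_{b=1}^{B} \\sum_{t=1}^{L} \\sum_{c=1}^{C} y_{b,t,c} \\, \\log p_{b,t,c}$$\n",
    "\n",
    "Positions added by length padding have $y_{b,t,c} = 0$ for every class, so they contribute nothing to the sum, and dividing by $\\sum_b \\ell_b$ averages over real words only."
   ]
  },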
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_epoch(loss_function, optimizer, model, train_data):\n",
    "    \n",
    "    ## For each batch, we must reset the gradients\n",
    "    ## stored by the model.   \n",
    "    total_loss = 0\n",
    "    for batch, labels, lengths in train_data:\n",
    "        # clear gradients\n",
    "        optimizer.zero_grad()\n",
    "        # evoke model in training mode on batch\n",
    "        outputs = model.forward(batch)\n",
    "        # compute loss w.r.t batch\n",
    "        loss = loss_function(outputs, labels, lengths)\n",
    "        # pass gradients back, startiing on loss value\n",
    "        loss.backward()\n",
    "        # update parameters\n",
    "        optimizer.step()\n",
    "        total_loss += loss.item()\n",
    "    \n",
    "    # return the total to keep track of how you did this time around\n",
    "    return total_loss\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "metadata": {},
   "outputs": [],
   "source": [
    "config = {\"batch_size\": 4,\n",
    "          \"half_window\": 2,\n",
    "          \"embed_dim\": 25,\n",
    "          \"hidden_dim\": 25,\n",
    "          \"num_classes\": 2,\n",
    "          \"freeze_embeddings\": False,\n",
    "         }\n",
    "learning_rate = .0002\n",
    "num_epochs = 10000\n",
    "model = SoftmaxWordWindowClassifier(config, len(word_2_id))\n",
    "optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_loader = torch.utils.data.DataLoader(list(zip(train_sents, train_labels)), \n",
    "                                           batch_size=2, \n",
    "                                           shuffle=True, \n",
    "                                           collate_fn=partial(my_collate, window_size=2, word_2_id=word_2_id))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[1.4967301487922668, 1.408476173877716, 1.3443800806999207, 1.2865177989006042, 1.2272869944572449, 1.1691689491271973, 1.1141255497932434, 1.0696152448654175, 1.023829996585846, 0.978839099407196, 0.937132716178894, 0.8965558409690857, 0.8551942408084869, 0.8171629309654236, 0.7806291580200195, 0.7467736303806305, 0.7136902511119843, 0.6842415034770966, 0.6537061333656311, 0.6195352077484131, 0.5914349257946014, 0.5682767033576965, 0.5430445969104767, 0.5190333724021912, 0.49760693311691284, 0.47582894563674927, 0.45516568422317505, 0.4298042058944702, 0.41591694951057434, 0.39368535578250885, 0.3817802667617798, 0.36694473028182983, 0.35200121998786926, 0.3370656222105026, 0.31913231313228607, 0.3065541982650757, 0.2946578562259674, 0.28842414915561676, 0.27765345573425293, 0.26745346188545227, 0.25778329372406006, 0.24860621988773346, 0.23990143835544586, 0.22729042172431946, 0.22337404638528824, 0.21637336909770966, 0.20889568328857422, 0.20218300074338913, 0.19230441004037857, 0.19007354974746704, 0.18426819890737534, 0.17840557545423508, 0.173139289021492, 0.16499895602464676, 0.1602725237607956, 0.1590176522731781, 0.15144427865743637, 0.14732149988412857, 0.14641961455345154, 0.13959994912147522, 0.13598214834928513, 0.13251276314258575, 0.13197287172079086, 0.12871850654482841, 0.1253872662782669, 0.12239058315753937, 0.1171659529209137, 0.11695125326514244, 0.11428486183285713, 0.11171672493219376, 0.10924769192934036, 0.10686498507857323, 0.1045713983476162, 0.10218603909015656, 0.10022115334868431, 0.09602915123105049, 0.09616792947053909, 0.09424330666661263, 0.09223027899861336, 0.090587567538023, 0.08691023662686348, 0.08717184513807297, 0.08540527895092964, 0.0839710421860218, 0.08230703324079514, 0.0808291956782341, 0.07777531817555428, 0.0780084915459156, 0.07678597420454025, 0.07535399869084358, 0.07408255711197853, 0.07296567782759666, 0.07176320999860764, 0.07059716433286667, 0.0694643184542656, 0.06684627756476402, 0.06579622253775597, 
0.06477534398436546, 0.06378135085105896, 0.06281331554055214]\n"
     ]
    }
   ],
   "source": [
    "losses = []\n",
    "for epoch in range(num_epochs):\n",
    "    epoch_loss = train_epoch(loss_function, optimizer, model, train_loader)\n",
    "    if epoch % 100 == 0:\n",
    "        losses.append(epoch_loss)\n",
    "print(losses)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Prediction."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_loader = torch.utils.data.DataLoader(list(zip(test_sents, test_labels)), \n",
    "                                           batch_size=1, \n",
    "                                           shuffle=False, \n",
    "                                           collate_fn=partial(my_collate, window_size=2, word_2_id=word_2_id))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[0, 0, 0, 1]])\n",
      "tensor([[0, 0, 0, 1]])\n"
     ]
    }
   ],
   "source": [
    "for test_instance, labs, _ in test_loader:\n",
    "    outputs = model.forward(test_instance)\n",
    "    print(torch.argmax(outputs, dim=2))\n",
    "    print(torch.argmax(labs, dim=2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
